Diffusion trainer fix: shift logits to align with input tokens #3191
base: main
Conversation
Important: Review skipped. Auto incremental reviews are disabled on this repository; check the settings in the CodeRabbit UI.

📝 Walkthrough

Introduces a new utility to shift logits to input positions and updates diffusion generation and training to apply this shift before token selection and loss computation. No public API signatures changed.

Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes

Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/axolotl/integrations/diffusion/utils.py (1)
162-166: Implementation looks correct for logits alignment. The function correctly shifts next-token prediction logits to align with input token positions by:
- Preserving the first logit position unchanged
- Shifting remaining logits left by one position
- Properly handling edge case of single-token sequences
The implementation aligns with the PR objective of adapting pretrained autoregressive models for diffusion fine-tuning.
However, consider adding a brief example in the docstring to clarify the transformation:
- """Align next-token logits with their input token positions for diffusion.""" + """Align next-token logits with their input token positions for diffusion. + + Example: [logit_for_pos1, logit_for_pos2, logit_for_pos3] + becomes: [logit_for_pos1, logit_for_pos1, logit_for_pos2] + """
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/axolotl/integrations/diffusion/generation.py (2 hunks)
- src/axolotl/integrations/diffusion/trainer.py (2 hunks)
- src/axolotl/integrations/diffusion/utils.py (1 hunk)
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/integrations/diffusion/trainer.py (1)
src/axolotl/integrations/diffusion/utils.py (2)
- create_bidirectional_attention_mask (125-159)
- shift_logits_to_input_positions (162-166)
src/axolotl/integrations/diffusion/generation.py (1)
src/axolotl/integrations/diffusion/utils.py (2)
- create_bidirectional_attention_mask (125-159)
- shift_logits_to_input_positions (162-166)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
🔇 Additional comments (4)
src/axolotl/integrations/diffusion/trainer.py (2)
14-14: LGTM: Import addition is correct. The import of shift_logits_to_input_positions from the utils module follows the existing import pattern.
210-210: No issues detected with logits shifting alignment or loss computation.
src/axolotl/integrations/diffusion/generation.py (2)
10-10: LGTM: Import addition is consistent. The import follows the same pattern as in the trainer module for consistency.
363-363: Logits shifting applied consistently in generation. The shift is correctly applied before token sampling in the diffusion step, maintaining consistency with the training logic. This ensures the same logits alignment is used during both training and generation.
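To make the training-side usage concrete, here is a hedged sketch of how the shifted logits could feed a masked-token cross-entropy loss; the helper name, the `masked_positions` mask, and the exact loss formulation are illustrative assumptions, not the trainer's actual code:

```python
import torch
import torch.nn.functional as F

from axolotl.integrations.diffusion.utils import shift_logits_to_input_positions


def masked_diffusion_loss_sketch(
    logits: torch.Tensor,            # (batch, seq_len, vocab) raw model outputs
    labels: torch.Tensor,            # (batch, seq_len) original (unmasked) token ids
    masked_positions: torch.Tensor,  # (batch, seq_len) bool, True where tokens were noised
) -> torch.Tensor:
    # Align next-token logits with the input positions before scoring.
    shifted = shift_logits_to_input_positions(logits)
    # Cross-entropy only over the positions that were masked/noised.
    return F.cross_entropy(shifted[masked_positions], labels[masked_positions])
```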
Codecov Report: ❌ Patch coverage is
Test failure is unrelated: out of space.
| """Align next-token logits with their input token positions for diffusion.""" | ||
| if logits.size(1) <= 1: | ||
| return logits | ||
| return torch.cat([logits[:, :1], logits[:, :-1]], dim=1) |
What is this trying to do? Concatenate the logits' first column and columns 1..N together?
It's a bit of a hack to use pretrained causal LMs for diffusion fine-tuning: we're shifting the logits to the right by one position so the input logits align with the output logits.
Unfortunately this duplicates the first position's logits, but I couldn't think of a better way to do it. Open to ideas here.
Force-pushed da80beb to 7f6f08e
Just as a data point: I had messed around with an early version of Dan's diffusion trainer a while back, and here's the change I made to support next-token prediction: cf8c93e. My changes may be unnecessary, but I wanted to make sure we didn't miss anything.
Description
Title.
Motivation and Context
Pretrained autoregressive models produce output logits that are effectively right-shifted by one position: the logits at position t predict the token at position t+1. By shifting the logits back onto their input positions, we should be able to use pretrained AR models effectively for diffusion fine-tuning!
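As an illustration of the offset being corrected (a sketch, not code from this PR): in standard causal-LM training the labels are shifted so that the logit at position t is scored against the token at position t+1; for diffusion fine-tuning the logits are shifted instead, so each position's logits line up with the input token at that same position.

```python
import torch

B, T, V = 1, 4, 8
logits = torch.randn(B, T, V)            # one logit vector per input position
input_ids = torch.randint(0, V, (B, T))

# Standard causal-LM loss: logits[:, t] is scored against input_ids[:, t + 1]
# (the usual "shift the labels" convention).
ar_logits, ar_labels = logits[:, :-1, :], input_ids[:, 1:]

# Diffusion fine-tuning (this PR): shift the logits instead, so that
# shifted[:, t] lines up with input_ids[:, t]; position 0's logits are duplicated.
shifted = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
assert shifted.shape == logits.shape
```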